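"""cnn_segment.py
Feed-forward (non-recurrent) CNN models built from stacks of 3x3 convolutions,
plus fixed-width factory functions for several depths.
"""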

import torch.nn as nn


class CNN(nn.Module):
    """Feed-forward stack of 3x3 convolutions with ReLU activations.

    Every convolution uses ``kernel_size=3, stride=1, padding=1``, so the
    output keeps the spatial size (H, W) of the input. The network contains
    exactly ``depth`` convolution layers:

    * ``first_layers``: ``in_channels -> width // 2 -> width`` (2 convs, each
      followed by ReLU)
    * ``middle_layers``: ``depth - 3`` convs ``width -> width`` (each followed
      by ReLU)
    * ``last_layers``: ``width -> out_channels`` (1 conv, no activation)

    Args:
        width: Channel count of the hidden feature maps; the first hidden conv
            uses ``width // 2`` channels. Must be >= 2.
        depth: Total number of convolution layers. Must be >= 3.
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the output map. Defaults to 2,
            the value that was previously hard-coded.

    Raises:
        ValueError: If ``depth < 3`` or ``width < 2``.
    """

    def __init__(self, width=64, depth=4, in_channels=3, out_channels=2):
        super().__init__()
        # The fixed first block (2 convs) plus the last conv already make 3
        # layers. Smaller depths used to be silently rounded up to 3 because
        # range(depth - 3) is simply empty.
        if depth < 3:
            raise ValueError(f"depth must be >= 3, got {depth}")
        # The first conv outputs width // 2 channels; width < 2 would make it 0.
        if width < 2:
            raise ValueError(f"width must be >= 2, got {width}")
        self.width = width
        self.depth = depth
        self.in_channels = in_channels
        self.out_channels = out_channels

        half_width = width // 2
        self.first_layers = nn.Sequential(
            nn.Conv2d(in_channels, half_width, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(half_width, width, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        )
        # Each middle layer is its own (Conv2d, ReLU) Sequential. Keep this
        # nesting so state_dict keys ("middle_layers.<i>.0.weight") stay
        # compatible with previously saved checkpoints.
        self.middle_layers = nn.Sequential(
            *[
                nn.Sequential(
                    nn.Conv2d(width, width, kernel_size=3, stride=1, padding=1),
                    nn.ReLU(),
                )
                for _ in range(depth - 3)
            ]
        )
        # No activation after the final conv: raw per-pixel values are returned.
        self.last_layers = nn.Sequential(
            nn.Conv2d(width, out_channels, kernel_size=3, stride=1, padding=1)
        )

    def forward(self, x):
        """Run the first, middle and last layer groups in sequence.

        Args:
            x: Input tensor whose channel dimension equals ``in_channels``,
                presumably an image batch of shape (N, in_channels, H, W).

        Returns:
            Tensor with ``out_channels`` channels and the same H, W as ``x``.
        """
        out = self.first_layers(x)
        out = self.middle_layers(out)
        out = self.last_layers(out)
        return out


def cnn_4():
    """Return a ``CNN`` with 4 convolution layers and 256 hidden channels."""
    model = CNN(width=256, depth=4)
    return model


def cnn_5():
    """Return a ``CNN`` with 5 convolution layers and 256 hidden channels."""
    model = CNN(width=256, depth=5)
    return model


def cnn_6():
    """Return a ``CNN`` with 6 convolution layers and 256 hidden channels."""
    model = CNN(width=256, depth=6)
    return model


def cnn_7():
    """Return a ``CNN`` with 7 convolution layers and 256 hidden channels."""
    model = CNN(width=256, depth=7)
    return model


def cnn_8():
    """Return a ``CNN`` with 8 convolution layers and 256 hidden channels."""
    model = CNN(width=256, depth=8)
    return model
